{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "001a42c0",
   "metadata": {},
   "source": [
    "## Multivariate Timeseries Graph Neural Network (MTGNN)\n",
    "The MTGNN model is a combination of graph convolution (GC) and temporal convolution (TC) blocks combined with a graph structure learning layer.<br />\n",
    "\n",
    "The graph construction layer assumes the information flow between nodes in the graph<br />\n",
    "is strictly uni-directional and learns a skew-symmetric adjacency matrix, meaning that $ A = -A^T$. <br />\n",
    "This matrix $A$ is then put through a ReLU function to remove the negative weights.\n",
    "\n",
    "The model structure looks like:<br /><br />\n",
    "<pre>\n",
    "    &#8625; res_channels &#10141;                 &#8628; <br />\n",
    "... &#10141;    TC   &#10141;     GC|A   &#10141; GC_out + res_channels &#10141; ... &#10141; last_layer_out <br />\n",
    "... &#10141;         &#8627; TC_out + skip channels &#10141;            &#10565; ... &#10565;     &#8627; last_layer_out + skip_channels  <br />\n",
    "                                                                                   &#8681; <br />\n",
    "                                                                              end_channels  \n",
    "                                                                                   &#8681; <br />\n",
    "                                                                              out_channels <br />\n",
    "<br /></pre>\n",
    "The skip connections are summed up after every temporal convolution and carried until the end of the model and used to compute the final output. <br />\n",
    "Residual connections go from the input of a TC/GC block to the output. <br />"
   ]
  },
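  {
   "cell_type": "markdown",
   "id": "mtgnn-skip-residual",
   "metadata": {},
   "source": [
    "The skip and residual paths described above can be written out as a simplified sketch. The notation is introduced here for illustration and is not taken from the library: $h^{(l)}$ is the input to layer $l$, $z^{(l)}$ the TC output, $s^{(l)}$ the running skip sum, $L$ the number of layers, and $\\mathrm{Conv}$ a convolution that only changes the number of channels. $\\mathrm{GC} \\mid A$ follows the diagram and means graph convolution over the learned adjacency. Dropout, layer normalization and the cropping of the residual along the time axis are left out.\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "z^{(l)} &= \\mathrm{TC}^{(l)}(h^{(l)}) \\\\\n",
    "s^{(l)} &= s^{(l-1)} + \\mathrm{Conv}_{\\text{skip}}^{(l)}(z^{(l)}) \\\\\n",
    "h^{(l+1)} &= \\mathrm{GC}^{(l)}(z^{(l)} \\mid A) + h^{(l)} \\\\\n",
    "\\hat{y} &= \\mathrm{Conv}_{\\text{out}}\\Big(\\mathrm{ReLU}\\big(\\mathrm{Conv}_{\\text{end}}(\\mathrm{ReLU}(s^{(L)} + \\mathrm{Conv}_{\\text{skip}}^{(L+1)}(h^{(L+1)})))\\big)\\Big)\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "With the settings used in the code cell below, $L$ corresponds to `layers=3`, the skip sum has `skip_channels=16` channels, $\\mathrm{Conv}_{\\text{end}}$ maps it to `end_channels=32`, and $\\mathrm{Conv}_{\\text{out}}$ maps that to `out_dim=1`."
   ]
  },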
  {
   "cell_type": "markdown",
   "id": "cb588c1e",
   "metadata": {},
   "source": [
    "### MTGNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7e473b3a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [06:53<00:00,  2.07s/it]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAEWCAYAAABsY4yMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAoAElEQVR4nO3deZxcVZn/8c9zq7qr987WWcgOkW1YM60iomyigg6IK44b6kxGZ1wYZRwdHWXcZhwdZ+QnPxURBARFUUYHxwUQQf2xJQISwhISkkDWXpL0vlTV8/vj3u5Ud6e7qzu5VZXK9/161eveunXr3qdvVz996pxzzzF3R0REyk9Q7ABERCQeSvAiImVKCV5EpEwpwYuIlCkleBGRMqUELyJSppTgRUqEmf3WzNzMLi12LFIelOAldma2KUpcryt2LKXAzM6KrsemUS/dCnwNWFf4qKQcJYsdgEg5M7MKdx/MZ193/3rc8cjhRSV4KTozu9jMHjKzTjPbbGZXmdmM6LVKM/u2me0ws34ze87M/id6zczsi9G2/mifX5nZ7HHOU2tmXzazDWbWZWaPmNk7oteONLOsmbWZWUW0bWlU0m6L4kia2cfM7Akz6zazdWa2Kuf4V0T732pmPzSzXuBto2I4C7g7ejp0fI9eG1FFY2bfjZ7fYGa/MLNeM/t1FNePoxjuM7PlOcc/wcx+bma7zKwl2m/JQfg1ySFICV6KyswuAH4CnBQtO4G/BX4Q7fJO4K+AVuA7wBrg9Oi1c4FPAJnotXuBE4H6cU53HXB5tP8PgRcAN5jZW919I/AHYBbwimj/N0fLW9x9APgc8CXAgJuBKuBbZvauUed5A3AUcCOwY9RrzwM/jtY7CatkvjZOvEPeDnQB7cB5wKPADGAjcFoUF2Y2P7oG5wG/B34LvB74lZmlJjmHlCEleCm2D0TLL7r7u4CzgDTwKjM7GqiIXn8MuAl4NzA32jb02jOECfsDwEJgy+iTmNlc4E3R0/Pc/T3AP0XPPxgtb4iWb4mWQwn+BjOznFj/H9ANrI2ev3/U6TYCL3b3Ve7+y9wX3P0ZYKgqpt3dL3P3y0bHO8pv3P1NwLej572ESXwo/lOj5TuAmYTXYwuwFWgBjgXOnuQcUoZUBy/FtixaPgHg7q1m1grMB5YSJt2zgIuASwAH7jSzi4FfA/+XMLENVXusBi4Eto9znl533xytPxktl0bLHwJXAq8zs2OBZuBpd7/fzJqAumi/d4869opRzx909/RkP/gUPBEt90TLZ9w9a2ad0fPaaLksWh4XPSaKUQ4DKsFLsW2KlscCRPXnc6Jtm4G0u78FaCBMWncSll5fDyQIS9UzCBPYDYRJ+a8mOE91Tp30MTnnwd33Aj8FGoGro9eGSvWthKV2gJPd3dzdCP+Gmkedq3+SnzkTLfP9+8tM8nzIpmh521B8UYwLCKuw5DCjBC+F9CUzuz/n8XLgqui1fzKz7xLWGyeBO9z9aeCtZvYEYf35hwnr2CEszZ4OPEtYdfMR4KU5r43g7rsIuyEC3GFm1wJfjJ7n9l4ZSugvI/y2cGP0fs+J9ddRw+/3CatjrpjSVYDnouUiM7vGzP5xiu8fz02EP/vFUWPzt8zszuh88w7SOeQQoioaKaSjRz2f5e7/bWZvBj4OvJGwIfFbhI2nAE8Rlp4vIGw83Q58HridsNS+nrCxdUa03zfZV/oe7T2Eye5iwnr2DcBX3f3mnH1+RdgwOh+4x91z6/M/BbQBlxI2fHYADwO35PnzA+Dum8zsK4TfNN4LPE7YeHtA3H2bmZ0JfAF4EXAGYV38VYTXRg4zpgk/RETKk6poRETKlBK8iEiZUoIXESlTSvAiImWqpHrRzJkzx5ctW1bsMEREDhlr1qxpdfem/b1WUgl+2bJlrF69uthhiIgcMsxs83ivqYpGRKRMKcGLiJQpJXgRkTKlBC8iUqaU4EVEypQSvIhImVKCFxEpU2WR
4K+8az33PN1S7DBEREpKWST4b/x2A394RsNdi4jkijXBm9kMM7vVzJ40syfM7CVxnCcwyGQ1rr2ISK64hyr4GvBLd3+jmVUCNXGcJAhMCV5EZJTYEryZNQIvJ5zeDHcfAAbiOFciMDQzlYjISHFW0SwHWoDrzOzhaHLh2tE7mdkqM1ttZqtbWqbXUBqYkVGCFxEZIc4EnwRWAt9w91OBbsKJlUdw96vdvdndm5ua9jvi5aQCM1RDIyIyUpwJ/nngeXd/IHp+K2HCP+gCg6wyvIjICLEleHffATxnZsdEm84F1sVxroQaWUVExoi7F80HgZuiHjQbgXfHcRJV0YiIjBVrgnf3R4DmOM8BEASQVSOriMgIZXEna8JMCV5EZJSySPCBqQ5eRGS08kjwgUrwIiKjlUWCT5iRzRY7ChGR0lIWCd4M3ckqIjJKWSR4jUUjIjJWWSR4NbKKiIxVHgk+MDLK7yIiI5RFgk8YqqIRERmlLBK8qmhERMYqjwSvfvAiImOUR4I31A9eRGSUskjwiUAzOomIjFYWCT7QYGMiImOUT4JXI6uIyAhlkeATgSb8EBEZrSwSfGCom6SIyChlkuBVBy8iMlpZJPiE+sGLiIxRFgled7KKiIxVHgk+MFSAFxEZqSwSfEITfoiIjFEWCV5VNCIiY5VHglcVjYjIGMk4D25mm4BOIAOk3b05jvOoH7yIyFixJvjI2e7eGucJ1E1SRGSs8qii0Y1OIiJjxJ3gHfi1ma0xs1X728HMVpnZajNb3dLSMq2TqJFVRGSsuBP8Ge6+Ejgf+Dsze/noHdz9andvdvfmpqamaZ1Eg42JiIwVa4J3963RchdwG/CiOM5jhoYLFhEZJbYEb2a1ZlY/tA68Elgbx7kSqoMXERkjzl4084DbzGzoPDe7+y/jOJGm7BMRGSu2BO/uG4GT4zp+LjPTpNsiIqOURTfJRICqaERERpkwwZtZwswuNLNjChXQdASmKhoRkdEmTPDungG+A7ykMOFMT2DhWDSuJC8iMiyfOvibgEvN7CFg+9BGd2+PLaopSgQGQNbDoYNFRCS/BP8hwjtS/5SzzfN8b0FE+Z1M1oeTvYjI4S6fJH0vYUIvWcFwCb6kwxQRKahJE7y7n1WAOA5IYErwIiKjTdpN0swazey7ZrYzelxrZo2FCC5fCdtXBy8iIqF8+sFfCbwTGIgelwL/FV9IUzdURaMRJUVE9sknwZ8P/Lu7L3b3xcCXgdfEG9bUDLWrasAxEZF9pnMna8ll0YQaWUVExsinF83/Av9gZn8ZPV8I3BBfSFMXDWimu1lFRHLkk+AvIyzpnx89vxH4+7gCmo6hRlbldxGRfSZM8GaWAP4ZuM7d31mYkKYuEVU0qZFVRGSffMaieR1wVEGimabhKholeBGRYflU0fwW+LSZpRg5Fs1P4gpqqlRFIyIyVj4J/t3R8spoaYQ9aRKxRDQNwVAVjTK8iMiwfBL8v8QexQHSUAUiImPl08jaANzu7ncXJqSpG+4Hrzp4EZFhZdHIGmgsGhGRMcqikTVQLxoRkTHKo5F1aCwa1cGLiAzLJ8F/lhIcfyaXxqIRERkrnwk/rihAHAdEwwWLiIw1biOrmf3RzM4zs9poko9jo+0Xm1neE26bWcLMHjaz2w9GwPujRlYRkbEm6kVzCjATqCKc5OOIaHslMJUZnT4MPDGN2PKWUD94EZExJhsP/oAyppktIpwc5JoDOc5khhpZVUUjIrLPZHXw7wHOI0z0HzCz1wEvmMLx/wv4GFA/3g5mtgpYBbBkyZIpHHqfQI2sIiJjTJbgX5Wz/rqc9UkzqZm9Ftjl7mvM7Kzx9nP3q4GrAZqbm6eVoffdyTqdd4uIlKeJEvzZB3jslwIXmtkFhPX4DWb2PXd/+wEedwz1gxcRGWvcBO/u9xzIgd39E8AnAKIS/OVxJHfIuZNVCV5EZNh0Jt0uOcPdJNXIKiIyLJ87WQ+Yu/+WcEybWOy7kzWuM4iIHHrK
qgSvbpIiIvuMW4I3s09P9EZ3/+zBD2d6hmZ0ctXBi4gMm6iK5oqcdSccRXJoHcJByEpCQo2sIiJjTJTg3xQtzwbOBP6TsErnw8B9Mcc1JaYqGhGRMSbqJvljADP7CvAFd782em6Ed6eWjKFGVhXgRUT2yacXTQr4TDSuTEA4AUhJNc4mVIIXERkjnwR/OeFgYUONrn2EY9SUDNOdrCIiY+Qz4cfNZnYncFq06X533xVvWFOjGZ1ERMbKt6rlhcA5wDPAK83s5PhCmrrE8IxORQ5ERKSETJrgzewy4H+ADwLzgdcDX443rKlRFY2IyFj5lOAvA36U8/xOYGUs0UyTZnQSERkrnwQ/E3g053kNkIgnnOnRYGMiImPl04vmQeD90frlwBnAH2KLaBqGZnTKKL+LiAzLpwT/QaCXcKiCVwPbCattSsa+GZ2U4UVEhkxYgjezBHA0YcPqUB+Vp9w9E3dgU6EZnURExpqwBB8l8u8Aze6+LnqUVHIHzegkIrI/+dTB3wRcamYPEVbPAODu7bFFNUVDCV75XURkn3wS/IcIhwj+U842z/O9BbHvRidleBGRIfkk6XvZNwZ8SRqqg1eCFxHZJ5+xaM4qQBwHxMww04xOIiK5Jk3w0fjvlwAnAlXRZnf3j8YZ2FQlzNTIKiKSI58qmquA9zF22r6SSvCBGaqhERHZJ58bnS4Gbo7WPwzcDXwutoimKQh0o5OISK58x6L5XbS+HbgVWBVbRNOUMFMjq4hIjnyqaHZE++0gnNmpEuiY7E1mVkXYAycVvf9Wd//M9EOdmKpoRERGyqcE/ylgA2Gdex+wl/zGoukHznH3k4FTgFeb2WkTv2X6gsA0VIGISI58ukl+L+fpD/I9sId9FruipxXRI7YMHJjGohERyZVPN8nf7Gezu/u5ebw3AawBVgBXufsD+9lnFVGd/pIlSyYNeDyJQHXwIiK58qmDP2s/2/LKpNHAZKeY2QzgNjM7wd3XjtrnauBqgObm5mln6LAOXgleRGRIPgm+KWd9JnAFOYOO5cPd95jZ3YTjya+dbP/pCMzIatJtEZFh+TSyes6jA3gKeNdkbzKzpqjkjplVA+cBT0470kkkAt3JKiKSK58SfCtjq2SeyuN9C4Dro3r4APihu98+xfjyZmpkFREZYaqjSWaATcBXJnuTu/8JOHXakU1RIjDdySoikqMsRpOEocHGih2FiEjpyKeb5LUTvOzu/t6DGM+0qYpGRGSkfKpoLmXsSJK56yWR4FVFIyIyUj4J/ivAi4DPEjaWfgp4iCnc1VoI6gcvIjJSPgn+ncC/uPtvAMzsaOAf3f0fYo1sigIzMuoHLyIyLJ8E3wv8azRQmAEXAm2xRjUNCQ02JiIyQj43Ov0VYZJ/B/B2oCfaVlI02JiIyEj5dJO8y8yWAsdGm55094F4w5q6QIONiYiMMGEJPppwmyihLyAcbuDMAsQ1ZYEZKsCLiOwzbgnezO4i7Ab5CjN7L9GIj9Frn3H3zxcgvrxpyj4RkZEmKsGfAPw8Wn9ftPwccA/w13EGNR1BgAYbExHJMVGCbwTazKyRcEyZLe5+BXA9MLcAsU1JWEWjBC8iMmSiBL+JcB7W70X7/TLavoQS7SapKhoRkX0mSvD/DBwDvIZwyOD/iLZfAtwfc1xTZmYov4uI7DNuI6u7/yiaj/VI4Al37zKzJPCXwI5CBZivhPrBi4iMMGE/eHdvI6c6xt3TwKNxBzUdqqIRERkpnztZDwmqohERGalsEnzCNFywiEiusknwQaA6eBGRXPnM6HQMcDmwDEhEm93dz40xrikLzHSjk4hIjnyGC/5vwu6SuUouk2pGJxGRkfKpopkF/CfhYGNN0aMk72RVfhcR2SefBH8DsAKoIyy5Dz1KSqDBxkRERsiniuajhAn9tTnbPM/3FkxgaCwaEZEc+STpe5lGid3MFhOW/udF77/a3b821ePkKxGokVVEJFc+MzqdNc1jp4GPuvsfzaweWGNm
d7j7umkeb0LhjE5xHFlE5NCUTzdJIxxg7ESgKtrs7v7Rid7n7tuB7dF6p5k9ASwE4knwqqIRERkhnyqaqwgn/HDAom1OWDefFzNbRjim/AP7eW0VsApgyZIl+R5yjIT6wYuIjJBPL5qLgZuj9Q8DdxPO7JQXM6sDfgxc5u4do19396vdvdndm5uamvI97P7Oo37wIiI58knwM4HfRevbgVuJStyTMbMKwuR+k7v/ZFoR5ikRqB+8iEiufKpodkT77QCuASqBMSXx0aK6++8QjiX/1QMJMh8aLlhEZKR8SvCfAjYQ1rn3AXuBy/J430uBdwDnmNkj0eOC6QY6GdOEHyIiI+TTTfJ7AGY2A1jq7v35HNjdf8++RtnYJcyU4EVEckxagjezZWb2EOG8rC8zs3vM7LPxhzY1qoMXERkpnyqabxL2XzcgS3hn6yVxBjUdprFoRERGyCfBnw58Pef5BmBRPOFMX8LC2iB1lRQRCeWT4FuBE6L1uYSl922xRTRNQVTbr3p4EZFQPt0kvw18IVq/KVp+PJ5wpi+IMnzGvbSGuRQRKZJ8etH8q5ltA14Tbbrd3W+IN6ypS0QJXgV4EZFQXoVdd78euD7mWA7IUBWNGlpFRELjJngzy0zwPnf3kqoJCYYaWVWEFxEBJi7BG+GokduAPQWJ5gAMJXiV4EVEQhP1orkO6AbmAI8BH3H3E4ceBYluCmbVVgLQ2jVQ5EhERErDuAne3d8LLAD+FlgM/NLMNpnZqwsV3FQsnFkNwNY9vUWORESkNEzYD97du4GNwLPAAGFpvr4AcU3ZwhlRgt+tBC8iAhMkeDP7pJmtB34DrAA+CCxw9x8VKripmNdQRTIwtu7pKXYoIiIlYaJG1s8RNrJuJLyb9ULgwnCYd9zdL4o/vPwlAmN+Y5VK8CIikcm6OhpwVPTIVZJdVRbOqFYdvIhIZKIEv7xgURwkC2dWc/+GtmKHISJSEsZN8O6+uZCBHAyLZlSzo6OPwUyWikQ+46iJiJSvssqCC2dWk3XYsbev2KGIiBRdeSX4GTWA+sKLiEC5JfiZ6gsvIjKkrBL8gsYqQCV4EREoswRfVZHgiMYqfv9MK65RJUXkMFdWCR7g/Wev4MFn27nt4a3FDkVEpKjKLsG/7UVLOHXJDD7/8yfY06ORJUXk8BVbgjeza81sl5mtjesc+xMExhcvPpE9PQN87a71hTy1iEhJibME/12gKEMLH7eggbe8cDE33reZDS1dxQhBRKToYkvw7n4v0B7X8SfzkfOOoaoiwav+817e8I3/x/qdncUKRUSkKIpeB29mq8xstZmtbmlpOWjHbapP8YNVp/HXLz+SzW09vOlb9/HIc3sO2vFFREqdxdmd0MyWAbe7+wn57N/c3OyrV68+6HFsaevh7d95gO7+NLd/6Azqqypwd+qrKg76uURECsnM1rh78/5eK3oJvhCWzK7h2ktfSN9ghrd9+wFe8sW7uPDrf6CrP13s0EREYnNYJHiAFXPr+Pc3nszm9h5WLp3J5rZuPv3TtXT0DZLJ6qYoESk/sVXRmNn3gbMI53HdCXzG3b8z0XviqqLJ1TeYoaoiwVfveJoro26Uxy9o4Ja/OU1VNiJyyJmoimayGZ2mzd3fGtexD0RVRQKAD52zgmWza3iuvZcrf7Oej/zwUf7PW08lGRi7ewaZU1dJND2hiMghKbYEX+qSiYDXr1wEQF1Vks/dvo5j//mXJAIjk3VOWtTIB85ewUtXzKE2ddheJhE5hClzAe956TKWzqrhyR0d9A1mqUkluOn+Lay6cQ3JwDhpUSPHLWiguz9NKplgVl0ls2oqSWedTDbLyiUzWbl0JlUVCfoGM1QkAhKBSv8iUlyxdpOcqkLUweerP53hgY3t3L+xjfs2tvFsazf1VUn6B7Ps7hlgMDPyulUmA5bPrmVDSxcLZlTxpdefxOkr5hQpehE5XExUB68EPw3uTmd/mmRUnfPQpnbu29DGUzu7OG5+Pb9e
t5NnW7s5fkEDf3PmkVx0ysL9Hmf73l7m1VcRqLQvItNUlEbWcmZmNOT0uDnn2Hmcc+y84eeXveJobn5wC7eueZ7LbnkEgP50lp7+NG954RKqKxP84MEtfOK2x3jvS5dz2XlH87FbH+VNzYs5+5i5I85191O7WDa7luVzagvys4lI+VAJPkZ9gxneds0DrNm8e3jbvIYUS2fV8uCmdmbUVLC3d5A/XzKT1Zt3M6cuxW8uP3P4n8f9G9t467fvZ0VTHb/48MtIJg6b2xZEJE+H/Z2sxVJVkeCadzZz6enL+O67X8gtq07jxIWNYPCO05bym4+exYKGKlZv3s2bmxfR1t3P529fx+a2btZu3ctHf/godakk63d1ceua54v944jIIUYl+CJbu3Uv929s471nLOeKnz3O9fdtHn4tGRi3/M1L+MLP1/Hc7l6++fY/58+XzixitCJSatTIeojIZp3Vm3ezqbWbuqokJy5sZPGsGtZu3cul1z1Ia9cArzlpAZ+/6ARm1lYWO1wRKQFK8GWguz/NNb97lq/fvZ6ZNZVc/qpjaKhK8rv1rVx0ykJetHwWADs7+hhIZ1k8q6bIEYtIISjBl5G1W/fyydse49Hn9wIM33m7cskMFs+q4ReP7aAyGfCj972E4xY0FDlaEYmbEnyZcXd++1QLmaxz2lGzufG+zfzq8R1saOni/BPmc+/TrQBccOICKpMBC2dUcf6JC5hTlypy5CJysCnBH2bWbetg1Y2r2dMzSH86w2DGaayu4Nzj5nLfhjZOWNjIO05byoyaCo6eVz88AJuIHHqU4A9j2azz9K5OPvPTx3n0+T2csWIOD23azd7eQSDsl/+u05dRl0ryyJY9PPzcHgbSWTJZJ5kwTlk8g7e8cDEve0FTkX8SEdkfJXgBwmQfBEZn3yCPPLeHvb2DXPeHTcM3Ys2sqeDFy2dTm0qSCKBnIMODz7bT0tXP5a88hlQyIDDj7GPn0lhdQWUyoLYyMe6wyj0DaZ5r7+WY+fWF/DFFDitK8DIud6e1awDHmV2bGjMKZu9Ahg9+/2HufGLnft+fSgbMqUtRm0qQdWiqS/FnRzTwyj+bz6d/upYnd3TyifOPZdXLj9T4+iIxUIKXA5LJOqs3tbN0di0D6Sy/f6aVgXSGvnSW9u4BWrv66enPYBZ201y7tYOBTJa6VJJTl8zgd+tbOe/4ebxh5SLWbG5nMOMsn1PLX5x8BLPUn1/kgCjBS0Ht6uzjvx/eyplHz+UFc+v4xj0b+OZvN9DZn6YyGVARGN0DGSqTAUtm1dCfzjC7NsVpR87mfWceyTO7unhiewcdfWlevHwWK5fMHB5x090ZyGSpCAKNwimCEryUgD09Azy+rYNTFs+gpjLB+l1d3PzAFnZ19lGZCNjR0ccDz7ZjwOg50KsqAhqqKhjIZOnqS5POOmZQV5mkNpWkrirJjOoKlsyqYensWhbOrKa2MsFAJsvOjj4efW4v2/f2MpDJ0lSXonnZLN6wchF/eKaV3sEM5x43lwWN1UW5LiIHSgleDglrt+7lJ3/cysmLGzntyNlUVSS4+8ldPL5tLx29aVIVAbWpZJi801k6+9N096fp6k/T1jXAlvYednT0MfojvWhmNUtn11CZCNi+t48nd3SOOfdJixp5yZGzmV1XyZb2HtZt62D9zi5m1VVy9Lx6XjC3jmRg9GeyVCYCKhMBGXe27u6lobqCY+fXc9yCBjp6B9nY2s0JCxs5orGK/nSWpvrUAXVFTWey7Orsp7s/jQMNVRX0DKTZtqePbXt6OXZBPSctmjHt48uhTQleDht9gxl2dvTRMxBOnTi7tnLMuD3rtnVwx7qdnL5iNjNrKrlj3U5+9fgO1m0L2w7qq5Icv6CBo+fV0949wFM7O3m2tZusOxWJgMFMdvifyNz6FB19g/QNZieMK5UMyLpTU5mkLpWkujIx3HYxs7aCpbNqqU0leKali9m1KWbXVvLY1r109A0ykM6O+VaTywze
ffpy6quS9AykqUtVsLyplvqqJM/v7uX59h4cOHXxDBqi3k8LGqs4IvrWsnbbXmbXpVg4o5qWzn4SgTGzpmK4Udzdh9f7BjNUVSRwd1q6+nlqRydP7ehkU1s3K5rqOHXJTJrqU8yqrRzzTy2dyZIITI3tB5kSvEge3J2O3jQN1ckxSSidyRKYDdf7pzNZHKhIBGSyzqa2bp7c3klNKsFRc+r409Y97O4ZJJUM2LG3j67+NIEZvQNpOvvT9A5kmF1XSW0qSVvXAJtau+nqT3PU3DpaO/tp6x7gpIWNNNWnooRcTUN1Enfo7EuTSgYsnFlNU32Kq+/ZyC2rn8MMqpIJegczI2KvTARgMJAe+U+ovipJTWWCnR39AMyqraS9ewCA6ooEVRUBA+ksvYMZZtZUMpDJ0tmXZlZtJe7O7p7B4WPVpZJ09afHnDfjTlNdirqqJJtau2moruDoeXUkgyCsequuoKGqgsbqCmpTCXoGMtG3sgxd0Te0vsEMCxqrmduQwgj/oRlGYNDUUEVDVZKWzn76o58vm3W2d/TR0TvIi5fPoj+dZUt7Dycc0UhLVz+rN7WTCALm1FWyeFYNjdXhvAwbWro4orGaJbNrqE8lR/Qoq0gEVCajRyLsLtw9kGZufYpls2vHtAf1pzPsiq5rIjCSQfjZyWad/nSWxpoK6lNjP2fToQQvUuZ2dfbRUFVBVUWC/nSGDbu66RlIs2hmDXPrUwxmszy5vZO+wbD309bdvazdtpc9PQOce+w82rr7eXpnF8fOr8fM2L6nl/50lspkmIjbuwepSBhNdSm27e0F4Jh59Rw9v55j5tUzuy7Fc+09PLG9g/buAdq6B+jsSw/3rOrsS3NUUx1tXf1sbO3G3ekbzLK3d5COvkE6+/b9c6iuSIRtK6lwWZkM2Lq7l909A7iDE/4znuhbTWN1BTWVCbbv7Rs+5tA/vmPm1WMGLdE/0iHzGlK0dg2QmejA+5EILHxYuAwMOvvTY6oKR6uqCLsYVyYD5tSm+OH7XjKl8w7RlH0iZW5ufdXweiqZ4PgjRg40lwoSnLx4RqwxLJ5VM+1RTDNZp3cwQ1UyyHvmsmzW2dXZT2ff4HA7x1DpvjIZ4O5sae+huiLBnLoUG1u7qEtVML9x37XqGUjT1ZemujJBfVUFfYMZWjr76exLk83J0IOZLAPpLAPRMutQU5lg655eNrd1k8462ayTyULWw6FBFs6oxiz82TLuZLJOYEYqGbCnZ5CWrn5aO/vpz2SpT8WTimNN8Gb2auBrQAK4xt3/Lc7zicihKREYdVNMckFgzG+sGpGwc5kZS2fvm8t4xdyxd1TXVCapqdx33qqKRFkNtR3blH1mlgCuAs4HjgfeambHx3U+EREZKc45WV8EPOPuG919APgBcFGM5xMRkRxxJviFwHM5z5+Pto1gZqvMbLWZrW5paYkxHBGRw0ucCT4v7n61uze7e3NTk4akFRE5WOJM8FuBxTnPF0XbRESkAOJM8A8BLzCz5WZWCVwC/CzG84mISI7Yukm6e9rMPgD8irCb5LXu/nhc5xMRkZFi7Qfv7v8L/G+c5xARkf0rqaEKzKwF2DzFt80BWmMI52Ao1dgU19Qorqkr1djKMa6l7r7fHiolleCnw8xWjzcOQ7GVamyKa2oU19SVamyHW1xF7yYpIiLxUIIXESlT5ZDgry52ABMo1dgU19Qorqkr1dgOq7gO+Tp4ERHZv3IowYuIyH4owYuIlKlDOsGb2avN7Ckze8bMPl7EOBab2d1mts7MHjezD0fbrzCzrWb2SPS4oAixbTKzx6Lzr462zTKzO8xsfbScWeCYjsm5Jo+YWYeZXVas62Vm15rZLjNbm7Ntv9fIQldGn7k/mdnKAsf1ZTN7Mjr3bWY2I9q+zMx6c67dNwsc17i/OzP7RHS9njKzVxU4rltyYtpkZo9E2wt5vcbLD/F/xtz9kHwQDn+wATgSqAQeBY4vUiwLgJXRej3wNOEkJ1cAlxf5
Om0C5oza9u/Ax6P1jwNfKvLvcQewtFjXC3g5sBJYO9k1Ai4AfgEYcBrwQIHjeiWQjNa/lBPXstz9inC99vu7i/4OHgVSwPLobzZRqLhGvf4fwKeLcL3Gyw+xf8YO5RJ8yUwo4u7b3f2P0Xon8AT7Gfu+hFwEXB+tXw+8rnihcC6wwd2negfzQePu9wLtozaPd40uAm7w0P3ADDNbUKi43P3X7j40Q/X9hKO0FtQ412s8FwE/cPd+d38WeIbwb7egcZmZAW8Gvh/HuScyQX6I/TN2KCf4vCYUKTQzWwacCjwQbfpA9DXr2kJXhUQc+LWZrTGzVdG2ee6+PVrfAcwrQlxDLmHkH12xr9eQ8a5RKX3u3kNY0huy3MweNrN7zOxlRYhnf7+7UrleLwN2uvv6nG0Fv16j8kPsn7FDOcGXHDOrA34MXObuHcA3gKOAU4DthF8RC+0Md19JODfu35nZy3Nf9PA7YVH6ylo4jPSFwI+iTaVwvcYo5jUaj5l9EkgDN0WbtgNL3P1U4CPAzWbWUMCQSvJ3l+OtjCxIFPx67Sc/DIvrM3YoJ/iSmlDEzCoIf3k3uftPANx9p7tn3D0LfJuYvppOxN23RstdwG1RDDuHvvJFy12FjityPvBHd98ZxVj065VjvGtU9M+dmV0KvBZ4W5QYiKpA2qL1NYR13UcXKqYJfnelcL2SwOuBW4a2Ffp67S8/UIDP2KGc4EtmQpGofu87wBPu/tWc7bn1ZhcDa0e/N+a4as2sfmidsIFuLeF1ele027uAnxYyrhwjSlXFvl6jjHeNfga8M+rpcBqwN+drduzM7NXAx4AL3b0nZ3uTmSWi9SOBFwAbCxjXeL+7nwGXmFnKzJZHcT1YqLgirwCedPfnhzYU8nqNlx8oxGesEK3IcT0IW5ufJvzv+8kixnEG4derPwGPRI8LgBuBx6LtPwMWFDiuIwl7MDwKPD50jYDZwF3AeuBOYFYRrlkt0AY05mwryvUi/CezHRgkrO9873jXiLBnw1XRZ+4xoLnAcT1DWD879Dn7ZrTvG6Lf8SPAH4G/KHBc4/7ugE9G1+sp4PxCxhVt/y7wvlH7FvJ6jZcfYv+MaagCEZEydShX0YiIyASU4EVEypQSvIhImVKCFxEpU0rwIiJlSgleylY0YqCPeuyJ4TxXRMd+48E+tsiBSBY7AJECeJhw5D6AgWIGIlJIKsHL4aCF8EaSO4G7zOzSqMT9vWgs8FYzu3xoZzP762iM7m4ze9DMzoi2V5rZv5rZ5mgs8XtHnedsC8dqbzGzN0XveWk0AFdftL3goxnK4UsJXg4HryRM8i2MHJbhbMJBsnYAXzazk83sHMIJkFsIB6FaAvzMzGYTjtn9ccI7ID9AeAdkrnOj4zUC/xZt+xjhHcV/B3wWaD3YP5zIeFRFI4eDB4BPReu7gROj9Wvd/VtmlgauAc4kTOgAn3H3O8xsCfBPhBMv/AXhLedv8XBc79G+6u5Xm9n7Ccc2gfA29NcS3pr+R8Jb0EUKQiV4ORy0uvud0WNNznYbtczlo5b5GJpsIs2+v61/JBzJcD3hmC2rLZpmTyRuKsHL4eAIM7sk53lFtHy3mW0BPhQ9v4dwAKiPAv9iZkcRJuXdhLMn/Q/QDNxiZrcCJ7n7ZZOc+xNAP2G1znOE09Y1AHsO8GcSmZQSvBwOTmXkZA9/Hy3vAv4WmA/8g7s/ChDNfPUx4KvAOuDv3b3NzP4NqAbeBpxDfsPeZoEPRudoI5wTdMsB/0QiedBoknLYiSbMuI4wqX+lyOGIxEZ18CIiZUoleBGRMqUSvIhImVKCFxEpU0rwIiJlSgleRKRMKcGLiJSp/w+KhEajJj5jFgAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train MSE: 0.1186\n",
      "Test MSE: 0.7675\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch_geometric_temporal.nn.attention import MTGNN\n",
    "from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader\n",
    "from torch_geometric_temporal.signal import temporal_signal_split\n",
    "\n",
    "from torch_geometric.data import DataLoader\n",
    "\n",
    "loader = ChickenpoxDatasetLoader()\n",
    "\n",
    "lags = 10\n",
    "stride = 1\n",
    "epochs = 200\n",
    "batch_size = 32\n",
    "\n",
    "dataset = loader.get_dataset(lags)\n",
    "\n",
    "sample = next(iter(dataset))\n",
    "num_nodes = sample.x.size(0)\n",
    "\n",
    "train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.4)\n",
    "\n",
    "train_loader = DataLoader(list(train_dataset), batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(list(test_dataset), batch_size=batch_size, shuffle=False)\n",
    "\n",
    "### MODEL DEFINITION\n",
    "class AttentionGCN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(AttentionGCN, self).__init__()\n",
    "\n",
    "        self.attention = MTGNN(\n",
    "            gcn_true=True,\n",
    "            build_adj=True,\n",
    "            gcn_depth=2,\n",
    "            num_nodes=num_nodes,\n",
    "            kernel_set=[5,5,5],\n",
    "            kernel_size=5,\n",
    "            dropout=0.2,\n",
    "            subgraph_size=num_nodes,\n",
    "            node_dim=1,\n",
    "            dilation_exponential=3,\n",
    "            conv_channels=9,\n",
    "            residual_channels=16,\n",
    "            skip_channels=16,\n",
    "            end_channels=32,\n",
    "            seq_length=lags,\n",
    "            in_dim=1,\n",
    "            out_dim=1,\n",
    "            layers=3,\n",
    "            propalpha=0.4,\n",
    "            tanhalpha=1,\n",
    "            layer_norm_affline=True,\n",
    "            xd=None\n",
    "        )\n",
    "\n",
    "    def forward(self, window):     \n",
    "        x = window.x.view(-1, 1, num_nodes, lags) # (batch, 1, lags, num_nodes)      \n",
    "        x = self.attention(x) # (batch, lags, num_nodes, 1)\n",
    "        return x.squeeze(-1).permute(0,2,1).flatten()\n",
    "    \n",
    "model = AttentionGCN()\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "### TRAIN\n",
    "model.train()\n",
    "\n",
    "loss_history = []\n",
    "for _ in tqdm(range(epochs)):\n",
    "    total_loss = 0\n",
    "    for i, window in enumerate(train_loader):\n",
    "        optimizer.zero_grad()\n",
    "        y_pred = model(window)\n",
    "        \n",
    "        assert y_pred.shape == window.y.shape\n",
    "        loss = torch.mean((y_pred - window.y)**2)\n",
    "        total_loss += loss.item()\n",
    "        \n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    total_loss /= i+1\n",
    "    loss_history.append(total_loss)\n",
    "\n",
    "### TEST \n",
    "model.eval()\n",
    "loss = 0\n",
    "with torch.no_grad():\n",
    "    for i, window in enumerate(test_loader):\n",
    "        y_pred = model(window)\n",
    "        \n",
    "        assert y_pred.shape == window.y.shape\n",
    "        loss += torch.mean((y_pred - window.y)**2)\n",
    "    loss /= i+1\n",
    "\n",
    "### RESULTS PLOT\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "x_ticks = np.arange(1, epochs+1)\n",
    "ax.plot(x_ticks, loss_history)\n",
    "\n",
    "# figure labels\n",
    "ax.set_title('Loss over time', fontweight='bold')\n",
    "ax.set_xlabel('Epochs', fontweight='bold')\n",
    "ax.set_ylabel('Mean Squared Error', fontweight='bold')\n",
    "plt.show()\n",
    "\n",
    "print(\"Train MSE: {:.4f}\".format(total_loss))\n",
    "print(\"Test MSE: {:.4f}\".format(loss))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9b8547e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
